/*

 * Licensed to the Apache Software Foundation (ASF) under one

 * or more contributor license agreements.  See the NOTICE file

 * distributed with this work for additional information

 * regarding copyright ownership.  The ASF licenses this file

 * to you under the Apache License, Version 2.0 (the

 * "License"); you may not use this file except in compliance

 * with the License.  You may obtain a copy of the License at

 *

 *     http://www.apache.org/licenses/LICENSE-2.0

 *

 * Unless required by applicable law or agreed to in writing, software

 * distributed under the License is distributed on an "AS IS" BASIS,

 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

 * See the License for the specific language governing permissions and

 * limitations under the License.

 */

package com.bff.gaia.unified.sdk.extensions.sql.impl.transform.agg;



import com.bff.gaia.unified.sdk.coders.CannotProvideCoderException;

import com.bff.gaia.unified.sdk.coders.Coder;

import com.bff.gaia.unified.sdk.coders.CoderRegistry;

import com.bff.gaia.unified.sdk.extensions.sql.impl.UdafImpl;

import com.bff.gaia.unified.sdk.extensions.sql.impl.transform.UnifiedBuiltinAggregations;

import com.bff.gaia.unified.sdk.schemas.Schema;

import com.bff.gaia.unified.sdk.schemas.SchemaCoder;

import com.bff.gaia.unified.sdk.values.Row;

import org.apache.calcite.rel.core.AggregateCall;

import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;

import com.bff.gaia.unified.sdk.transforms.Combine;



import javax.annotation.Nullable;



/**
 * Wrapper {@link Combine.CombineFn}s for aggregation function calls.
 *
 * <p>{@link #createCombineFn} wraps either a user-defined aggregate's {@link Combine.CombineFn} or
 * a built-in aggregation so that inputs reported as {@code null} are skipped instead of being
 * passed to the underlying function. {@link #createConstantCombineFn} returns a function that
 * ignores its inputs and always yields an empty {@link Row}.
 *
 * <p>All members are static; the type parameter {@code T} is not used by this class.
 */
public class AggregationCombineFnAdapter<T> {

  /**
   * Base class that delegates every {@link Combine.CombineFn} operation to a wrapped function,
   * except that {@link #addInput} first maps the input through {@link #getInput} and skips it when
   * the mapped value is {@code null}.
   *
   * @param <T> the input element type
   */
  private abstract static class WrappedCombinerBase<T> extends Combine.CombineFn<T, Object, Object> {

    /** The wrapped function that performs the actual aggregation. */
    final Combine.CombineFn<T, Object, Object> combineFn;

    WrappedCombinerBase(Combine.CombineFn<T, Object, Object> combineFn) {
      this.combineFn = combineFn;
    }

    @Override
    public Object createAccumulator() {
      return combineFn.createAccumulator();
    }

    /**
     * Adds {@code input} to {@code accumulator} through the wrapped function, unless
     * {@link #getInput} maps it to {@code null}, in which case the accumulator is returned
     * unchanged.
     */
    @Override
    public Object addInput(Object accumulator, T input) {
      T processedInput = getInput(input);
      if (processedInput == null) {
        return accumulator;
      }
      // Forward the value already computed above rather than calling getInput a second time.
      return combineFn.addInput(accumulator, processedInput);
    }

    @Override
    public Object mergeAccumulators(Iterable<Object> accumulators) {
      return combineFn.mergeAccumulators(accumulators);
    }

    @Override
    public Object extractOutput(Object accumulator) {
      return combineFn.extractOutput(accumulator);
    }

    /**
     * Returns the value to feed to the wrapped function for {@code input}, or {@code null} if the
     * input should be skipped.
     */
    @Nullable
    abstract T getInput(T input);

    @Override
    public Coder<Object> getAccumulatorCoder(CoderRegistry registry, Coder<T> inputCoder)
        throws CannotProvideCoderException {
      return combineFn.getAccumulatorCoder(registry, inputCoder);
    }
  }

  /**
   * Combiner used for calls with more than one argument: an input {@link Row} is skipped when any
   * of its values is {@code null}.
   */
  private static class MultiInputCombiner extends WrappedCombinerBase<Row> {

    MultiInputCombiner(Combine.CombineFn<Row, Object, Object> combineFn) {
      super(combineFn);
    }

    @Override
    Row getInput(Row input) {
      for (Object value : input.getValues()) {
        if (value == null) {
          return null;
        }
      }
      return input;
    }
  }

  /**
   * Combiner used for calls with zero or one argument: the input is passed through unchanged, so
   * only a {@code null} input is skipped.
   */
  private static class SingleInputCombiner extends WrappedCombinerBase<Object> {

    SingleInputCombiner(Combine.CombineFn<Object, Object, Object> combineFn) {
      super(combineFn);
    }

    @Override
    Object getInput(Object input) {
      return input;
    }
  }

  /**
   * A {@link Combine.CombineFn} that ignores all inputs and always produces a {@link Row} with an
   * empty schema, both as accumulator and as output.
   */
  private static class ConstantEmpty extends Combine.CombineFn<Row, Row, Row> {

    private static final Schema EMPTY_SCHEMA = Schema.builder().build();

    private static final Row EMPTY_ROW = Row.withSchema(EMPTY_SCHEMA).build();

    public static final ConstantEmpty INSTANCE = new ConstantEmpty();

    /** Use {@link #INSTANCE}. */
    private ConstantEmpty() {}

    @Override
    public Row createAccumulator() {
      return EMPTY_ROW;
    }

    @Override
    public Row addInput(Row accumulator, Row input) {
      return EMPTY_ROW;
    }

    @Override
    public Row mergeAccumulators(Iterable<Row> accumulators) {
      return EMPTY_ROW;
    }

    @Override
    public Row extractOutput(Row accumulator) {
      return EMPTY_ROW;
    }

    @Override
    public Coder<Row> getAccumulatorCoder(CoderRegistry registry, Coder<Row> inputCoder)
        throws CannotProvideCoderException {
      return SchemaCoder.of(EMPTY_SCHEMA);
    }

    @Override
    public Coder<Row> getDefaultOutputCoder(CoderRegistry registry, Coder<Row> inputCoder) {
      return SchemaCoder.of(EMPTY_SCHEMA);
    }
  }

  /**
   * Creates either a UDAF or a built-in {@link Combine.CombineFn} for {@code call}, wrapped so that
   * null inputs are skipped.
   *
   * @param call the aggregate call to implement
   * @param field the field whose type selects the built-in aggregation implementation
   * @param functionName the built-in aggregation name (unused for user-defined aggregates)
   * @return a single-input combiner when the call has at most one argument, otherwise a
   *     multi-input combiner over {@link Row}s
   * @throws IllegalArgumentException if the call uses {@code DISTINCT}
   * @throws IllegalStateException if a user-defined aggregate's combine function cannot be obtained
   */
  // The concrete element/accumulator types of the selected CombineFn are not known statically,
  // so a raw local is used and handed to the Object/Row-typed wrappers.
  @SuppressWarnings({"unchecked", "rawtypes"})
  public static Combine.CombineFn<?, ?, ?> createCombineFn(
      AggregateCall call, Schema.Field field, String functionName) {
    if (call.isDistinct()) {
      throw new IllegalArgumentException(
          "Does not support " + call.getAggregation().getName() + " DISTINCT");
    }

    Combine.CombineFn combineFn;
    if (call.getAggregation() instanceof SqlUserDefinedAggFunction) {
      combineFn = getUdafCombineFn(call);
    } else {
      combineFn = UnifiedBuiltinAggregations.create(functionName, field.getType());
    }

    // Zero- and one-argument calls share the pass-through wrapper.
    if (call.getArgList().size() <= 1) {
      return new SingleInputCombiner(combineFn);
    }
    return new MultiInputCombiner(combineFn);
  }

  /**
   * Returns a shared {@link Combine.CombineFn} that ignores its inputs and always outputs an empty
   * {@link Row}.
   */
  public static Combine.CombineFn<Row, ?, Row> createConstantCombineFn() {
    return ConstantEmpty.INSTANCE;
  }

  /**
   * Extracts the {@link Combine.CombineFn} backing a user-defined aggregate call.
   *
   * @throws IllegalStateException if the aggregation's function is not a {@link UdafImpl} or its
   *     combine function cannot be obtained; the original exception is kept as the cause
   */
  private static Combine.CombineFn<?, ?, ?> getUdafCombineFn(AggregateCall call) {
    try {
      return ((UdafImpl) ((SqlUserDefinedAggFunction) call.getAggregation()).function)
          .getCombineFn();
    } catch (Exception e) {
      throw new IllegalStateException(
          "Cannot create CombineFn for user-defined aggregation "
              + call.getAggregation().getName(),
          e);
    }
  }
}